{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# `Attention`机制"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:10:12.441667Z",
     "start_time": "2020-04-23T08:10:12.329033Z"
    }
   },
   "source": [
    "- 在机器翻译中，点积注意力的用法如下所示\n",
    "![](../images/scaled-dot-product-attention.png)\n",
    "\n",
    "> - 图中`Decoder`的隐藏状态向量`query` $q_t$ 与`Encoder`的隐藏状态向量序列`key` $K=[k_1, k_2, k_3, k_4]$ 每个向量做点积，每个 `key`向量对应一个分数\n",
    "- 然后`sofmax`转换为权重分布，权重值越大可以认为越“关注”该处向量的信息\n",
    "- 每个位置的 $v$ 与该处的权重相乘，然后求和得到 $q_t$ 对 $K$ 做注意力的输出向量"
   ]
  },
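  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three steps above can be restated compactly. This is a sketch in the notation of the figure, with $v_i$ the value vector at position $i$:\n",
    "\n",
    "$$\n",
    "s_i = q_t^\\top k_i, \\qquad \\alpha_i = \\frac{\\exp(s_i)}{\\sum_j \\exp(s_j)}, \\qquad c_t = \\sum_i \\alpha_i v_i\n",
    "$$\n",
    "\n",
    "Here $s_i$ are the scores, $\\alpha_i$ are the softmax weights, and $c_t$ is the attention output for $q_t$. The scaled variant used in the Transformer also divides each $s_i$ by $\\sqrt{d_k}$ (the key dimension) before the softmax. The cells below use the plain, unscaled dot product."
   ]
  },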
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `attenntion from scratch`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:08:13.684900Z",
     "start_time": "2020-05-05T15:08:13.259380Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJAAAAEYCAYAAACz0n+5AAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAARuklEQVR4nO2dfZBU1ZmHnx+oZckoopTIl5rKumSjWXAl7FpEC9eIaFFiLGTxj4iu7piU7mrFZFG3EjVWbenuRiuRXZVVKljlmkQNkUREJlZSShKVAcGPBcQPjJPhIwoiRBJ25N0/+iJNc7t7pk/fPnPvvA/V1X3P/Tinh6fec+7HeVtmhuM0yqDYDXDyjQvkBOECOUG4QE4QLpAThAvkBOECFQhJYyX9QtJaSa9Jui4pP0ZSh6QNyfuwKvvPSbbZIGlOr+r060DFQdJIYKSZrZJ0JLASuAi4HNhmZndIuhEYZmZzK/Y9BugEJgKW7Hu6mW2vVadHoAJhZpvMbFXyeSewFhgNzAAWJpstpCRVJecBHWa2LZGmA5hWr04XqKBIOgk4DXgBGGFmm6AkGXBcyi6jgXfLlruSspocEtrQXuB95H5UvnCbbuvz3+ZWbr0aaC8rmm9m8w+oRGoDHgeuN7MPpQOq7VXbEuq2rxUC0f1Udyuq6deMOn9UU46TyDK/2npJh1KS52Ez+3FSvEXSSDPblIyTtqbs2gVMKVseA/yyXnu8C4uJGnjVOlwp1DwIrDWzu8pWLQb2nVXNAZ5I2f1pYKqkYclZ2tSkrCYtiUBOOr3sWvrCZODLwCuSVidlNwN3AD+SdCXwW+CSpP6JwFfM7Coz2ybpdmBFst+3zWxbvQpdoJg02R8zW17jqOekbN8JXFW2vABY0Jc6XaCIaFDTI1DLcYEiomaHoAi4QDHJvz8uUEwyGES3HBcoJvn3xwWKiUcgJ4z8++MCxcQjkBNG/v1xgWLiEcgJI//+uEAx8SvRThj598cFiomPgZww8u+PCxQTj0BOGPn3xwWKiUcgJ4z8++MCxcQjkBNG/v1xgWLiEcgJwgVywsjAH0kLgOnAVjM7NSn7ITAu2eRo4AMzm5Cy70ZgJ/Ax0GNmE+vV5wJFJKMI9H1gHvDQvgIz+7uyOr8D7Kix/9lm9l5vK3OBYpKBP2b2bJLa5eDqSsbOAv62WfV5coWISOrzK5AzgS1mtqHKegOWSVopqb3KNgfgESgmDfiQ/MfWzA9Ug0uBR2qsn2xm3ZKOAzokrTOzZ2sd0AWKSCMRpV5+oBp1HQJcDJxe49jdyftWSYuASUBNgbwLGzh8EVhnZl1pKyUNSRJzImkIpfxAr9Y7qAsUkSzGQJIeAX4DjJPUleQEAphNRfclaZSkJcniCGC5pDXAi8CTZra0Xn2F78Jm3zabIw4/gkEaxODBg7n/hvtjN2k/2ZyFXVql/PKUsm7gguTzW8D4vtZXeIEA7r7mboa2DY3djIMYEFeiJX2GUp7h0ZRO87qBxWa2NuO2FZ/8+1N7DCRpLvADSl/1RUr58wQ8kmQ87/dI4hv3fYP2/2jnp7/+aezmHECE60BNp14EuhI4xcz+r7xQ0l3Aa5SSN/Zr7rnuHoYPHc72ndv5+r1f54QRJzD+033u6rOh//nQZ+qdhe0F0hIcj0zWpSKpXVKnpM758/t8yaKpDB86HIBhRw7jzM+dybp31kVtTzkDIQJdDzwjaQP70+CfAPwZcG21nSoudlmsROO7/7QbM+OIw49g959207m+k8vOuyxKW1Lpfz70mZoCmdlSSX9O6YrkaEpfuQtYYWYft6B9QWzfuZ1vLvgmAB/v/Zgv/tUXmfQXkyK3aj/9MaL0lbpnYWa2F3i+BW1pOqOGj+LBf34wdjOqk39/BsZ1oP6KJ1dwwsi/Py5QTDxTvRPEgBhEOxmSf39coJh4BHLCyL8/LlBMihCB/IlEJwiPQBEpQgRygWKS
f39coJh4BHLCyL8/LlBMPAI5YeTfHxcoJkWIQH4dKCZq4FXvkNICSVslvVpWdquk30lanbwuqLLvNEnrJb3R21k3LlBEMnqo/vvAtJTyu81sQvJaUrlS0mDgP4Hzgc8Cl0r6bL3KXKCCkaRj2dbArpOAN8zsLTPbQ2k+4Ix6O7lAEWnxtJ5rJb2cdHHDUtaPZv/MGyhNnhhd76AuUEwaGAOVz7lLXr3JJHYv8GlgArAJ+E6V1lRi9Q7sZ2ERaVWCKTPbUlbnfwM/S9msCxhbtjyGUh6EmngEikkGZ2Gp1Ugjyxa/RHriqBXAyZI+JekwSvmEFtc7tkegiGRxHShJMDUFGC6pC7gFmCJpAqUuaSNwdbLtKOABM7vAzHokXQs8DQwGFpjZa/Xqc4EiksWsjCoJplJnV5YnmEqWlwAHneLXwgWKSf4vRLtAMSnCrQwXKCb598cFionPjXfCyL8/LlBMfAzkhJF/f1ygmHgEcsLIvz8uUEw8Ajlh5N+f1gg06vy0VNOORyAnjPz70yKB4uQZ71+kBGG/Eu2EkX9/XKCY+BjICSP//rhAMfEI5ISRf39coJh4BHKC8J86cMLIvz8uUEyKcCHRZ6bGpHX5gf5d0rokucIiSUdX2XejpFeSHEKdvfkKLlBEWpgfqAM41cz+EngduKnG/mcnOYQm9qYyFygmGUSgtPxAZrbMzHqSxecpJU5oCi5QRCL97PffA09VWWfAMkkre5k2xgfRUWnAh+Q/tvw/d36S8qU3+/4L0AM8XGWTyWbWLek4oEPSuiSiVcUFikir8gMldc0BpgPnmFlq4qgk2QJmtlXSIkpp72oK5F1YTFqXH2gaMBe40Mw+qrLNEElH7vsMTCU9j9ABuEARyWIMlOQH+g0wTlKXpCuBecCRlLql1ZLuS7YdJWlfOpcRwHJJa4AXgSfNbGm9+rwLKxiN5gcys7eA8X2tzwWKiN9MdcLIvz8uUEw8Ajlh5N8fFygmHoGcMPLvjwsUE49AThj598cFikkRIpDfynCC8AgUEZ+V4QRRhC7MBYpJ/v1xgWLiEcgJI//+uEAxKUIEKvxp/E133sQZXzqD6VdMj92Ug2nRI61ZUniBLp52MQ/c+UDsZqQSaVpPUym8QJ8f/3mGHjU0djPSGcgRSNIVzWzIQEQN/OtvhESg26qtkNQuqVNS5/z5fZ7CNHAoQASqeRYm6eVqqyhNA0mlYvKbeZ7odPrjmKav1DuNHwGcB2yvKBfw60xaNJDIvz91u7CfAW1m9k7FayPwy8xb1wS+dvvXmH3NbN5+923OuuQsHn3y0dhN+oSMJham5Qc6RlKHpA3J+7Aq+85JttmQTIWuX1+VadLNxLsw2PdTBwcY8OzcZ/v8xz/rzrNqWiTpLGAX8JCZnZqU/RuwzczukHQjMMzM5lbsdwzQCUyklKVjJXC6mVX2PgdQ+NP4/kwWESgtPxAwA1iYfF4IXJSy63lAh5ltS6Tp4OBEVQfhtzJi0rox0Agz2wRgZpuS9C2VjAbeLVvuSspq4hEoIo1EoPJLJMmrV4mgetOclLK6XaxHoJg0EIEazA+0RdLIJPqMBLambNMFTClbHkMvTpQ8AkWkhVeiFwP7zqrmAE+kbPM0MFXSsOQsbWpSVhMXKCbZpPlNyw90B3CupA3AuckykiZKegDAzLYBtwMrkte3k7KaeBcWkSweqq+SHwjgnJRtO4GrypYXAAv6Up8LFJGBcCvDyZL8++MCxcQjkBNG/v1xgWLiEcgJI//+uEAx6Y+PqPYVFygm+ffHBYqJj4GcMPLvjwsUE49AThj598cFiolHICeM/PvjAsWkCBHIHyhzgvAIFJEiRCAXKCb598cFiolHICeM/PvjAsXEI5AThP/UgRNG/v3x60AxaXZ2DknjJK0ue30o6fqKbaZI2lG2zbdCvoNHoAJhZuuBCQCSBgO/AxalbPqcmTUlcbYLFJGMB9HnAG+a2TtZVuJdWEyyzdI6G3ikyrozJK2R
9JSkUxpqe4ILFJGs8gNJOgy4EEhLCLkKONHMxgP3AD8J+Q6t6cJGtaSW/JFdfqDzgVVmtiVl/w/LPi+R9F+ShpvZe31vjY+BopLhGOhSqnRfko4HtpiZSZpEqRd6v9GKWiLQbaqa1H7AcIvdcnBhBv5IOoJSDqCry8q+AmBm9wEzga9K6gF2A7MtIFWvR6CIZBGBzOwj4NiKsvvKPs8D5jWrPhcoJgW4Eu0CRcRvpjph5N8fFygmnlzBCSP//rhAMfExkBNG/v1xgWLiEcgJI//+uEAx8QjkBFGEh+r9eSAnCI9AEfEuzAkj//64QDHxCOSEkX9/XKCYeARywsi/Py5QTDwCOWHk3x8XKCYegZww8u+PCxQTf6TVCSObiYUbgZ3Ax0CPmU2sWC/gu8AFwEfA5Wa2qtH6XKCIZDgGOrvGXPfzgZOT118D9ybvDeF342OSbXqXaswAHrISzwNHSxrZ6MFcoIg0O8VdggHLJK1MS/0CjAbeLVvuSsoawruwmDQQURIpysWYn6R82cdkM+uWdBzQIWmdmT1bp1ZPrpBHGhkD1csPZGbdyftWSYuASUC5QF3A2LLlMUB3nxuS4F1YTJo8BpI0RNKR+z4DU4FXKzZbDFymEn8D7DCzTY1+BY9AEcngLGwEsCg57iHA/5jZ0or8QEsoncK/Qek0/oqQCl2gAmFmbwHjU8rL8wMZcE2z6nSBIlKEWRmFE+ioMUdx0UMX0XZ8G7bXWDV/FS987wUOH3Y4M384k6NPOpoPNn7AY7Me448f/DFqW/1maj9kb89elt2wjM0vbeawtsNoX9nOmx1vMuHyCbz9zNv86s5fMXnuZL5w4xf4+Y0/j9vY/PtT/yxM0mcknSOpraJ8WnbNapxdm3ex+aXNAOzZtYffr/09R40+inEzxrFm4RoA1ixcw7iLxsVsJpDZhcSWUlMgSf8EPAH8I/CqpBllq/81y4Y1g6EnDmXkaSPpeqGLthFt7Nq8CyhJNuS4IZFbR6xbGU2lXhf2D8DpZrZL0knAY5JOMrPv0i+/zn4OHXIosx6fxdLrl7Jn557YzUmlP0aUvlJPoMFmtgvAzDZKmkJJohOpIVD55fb777+/SU3tPYMOGcSsx2fxysOvsG7ROgB2bdlF2/GlKNR2fBt/2PqHlrfrIPLvT90x0GZJE/YtJDJNB4YDn6u2k5nNN7OJZjaxvT3tfl62XPjghby39j2ev/v5T8peX/w64+eULpGMnzOe9U+sb3m7KinCGKheBLoM6CkvMLMeSpfCWx9aesHYyWMZf9l4try8hatfKiVrf+bmZ1h+x3Jm/mgmp115Gjt+u4NHL0n7HZIW0/986DMKyHLfW8x/6uCTnzo4QJn3n3u/z3/8Y888tl9pV7jrQLmiX6nQGC5QRPrjmKavuEAxyb8/LlBMPAI5YeTfHxcoJh6BnDDy748LFBOPQE4Y+ffHBYqJJ1dwwsi/Py5QTPyheicIH0Q7YeTfH5/aHJNmP1AmaaykX0haK+k1SdelbDNF0g5Jq5PXt0K+g0egmDQ/AvUAN5jZqmSO/EpJHWb2vxXbPWdm05tRoQsUkWaPgZIkCZuSzzslraWU+6dSoKbhXVjOkNQuqbPslfrQeTKL5jTghZTVZ0haI+kpSaeEtMcjUESyyA+UHLcNeBy43sw+rFi9Cjgxmap1AfATSvkSG8IjUEwymFgo6VBK8jxsZj+uXG9mH5ZN1VoCHCppeKNfwSNQRJo9BkpS+D4IrDWzu6psczywxcxM0iRKQeT9Rut0gWLS/LOwycCXgVckrU7KbgZOgE/yBM0EviqpB9gNzLaAqTkuUEQyOAtbTh0tzWweMK9ZdbpAMSnAlWgXKCJ+L8wJI//+uEAxKUIE8utAThAegSJShAjkAsUk//64QDHxCOSEkX9/XKCYeARygvBZGU4Y+ffHBYpJEbqwliTZzLqCHHGgMd0N/G1G9a+41QqB+gWS2it+W9RpAgPpVkbrM54PAAaSQE4G
uEBOEANJIB//ZMCAGUQ72TCQIpCTAYUXSNI0SeslvSHpxtjtKRqF7sIkDQZeB84FuoAVwKUp2SqcBil6BJoEvGFmb5nZHuAHwIw6+zh9oOgCjQbeLVvuSsqcJlF0gdLuGxW3z45A0QXqAsaWLY8BuiO1pZAUXaAVwMmSPiXpMGA2sDhymwpFoZ8HMrMeSdcCTwODgQVm9lrkZhWKQp/GO9lT9C7MyRgXyAnCBXKCcIGcIFwgJwgXyAnCBXKCcIGcIP4ff0rRa+oKEHUAAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 108x324 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# 以 decoder 的 hidden state 作为 query\n",
    "dec_hidden_state = [5, 1, 20]\n",
    "\n",
    "\n",
    "# 可视化 query 向量\n",
    "%matplotlib inline\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import seaborn as sns\n",
    "\n",
    "plt.figure(figsize=(1.5, 4.5))\n",
    "sns.heatmap(np.transpose(np.matrix(dec_hidden_state)),\n",
    "            annot=True,\n",
    "            cmap=sns.light_palette(\"purple\", as_cmap=True),\n",
    "            linewidths=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:08:18.050373Z",
     "start_time": "2020-05-05T15:08:17.968046Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAIYAAAEYCAYAAACZYo4WAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAO6UlEQVR4nO3df5BdZX3H8feX3QSWbDKAhJAQIQyKIWKzLRKpWAwQNbUoSpUxYG1TauxoECeNBDrDVGyLyYwS+xfNAgEcfskoKM3U2EiIqGkTCkZKiECKqaYJCRQwpNVNdvPtH+dscnd59p69u+fe57l3P6+ZHXbP3Xvuk8tnv8+5597veczdERnsqNgDkDQpGBKkYEiQgiFBCoYEKRgSpGC0IDNrM7Ofmtma/Oc7zewXZrYl/+oq2kd7/YcpEVwDbAMmVWz7ort/a7g7UMVoMWY2Hfgj4LbR7EfBaD1fB64FDg3a/vdm9pSZrTSzo4t20oipROfcj7ABP91rNT83diWfARZVbOp2924AM7sE2OvuT5jZ3IrfuR54ERgPdAPLgC9Xe5zGHGP8ZldDHiZpHdNK2U0egu4hbj4f+LCZfRA4BphkZne7+yfz23vM7A5gadHjaCqJykbwNTR3v97dp7v7DOATwHp3/6SZTQUwMwM+AjxdNDK9KonJqv+PLtE9ZjaZLFlbgL8suoOCEVX9guHuG4AN+fcX1Xp/BSMmS3cmT3dkEpUqRlQNO8aomYIRU+MOPmumYESlYEiQgiFBCoaE6BhDwhQMCVIwJCThqURnPiVIFSOqdCuGghFTwlOJghGVgiFBCoaEaCqRMAVDghQMCUl4KtEJrqjKbR84vNc3NjWfbmabzOx5M/ummY0v2oeCEZNZ7V/D09/U3G8FsNLd3wq8ClxVtAMFI6ryK8bgpua8yegioL/T/S6ypqOqdIwRVV2OMfqbmifmP78JeM3de/OfdwKnFO1EFSOmEUwlZrbIzP694mvRkd0daWqufJTAIxc2U6tiRFV7xai1qZmsghxnZu151ZgOFHaZq2JE1ZCm5iuBR4GP5b/2p8B3i0amYERVn5erAcuAJWa2neyY4/aiO2gqiamOJ7gGNTW/AMyp5f6qGBKkihFVuqfEWzoYPT0HuPLPr+HAwQP09fbxgXnv5fOfXRh7WEck/F5JSwdj/Phx3HXrzUw4toODB3u5YuHVXPCed9H1O7NiDy3XxMEws5nApWRny5zsNfDD7r6t6h0TYGZMOLYDgN7eXnp7+xL7I01qMANUPfg0s2XA/WT/gs3A4/n395nZdfUf3uj19fVx6eV/wbsv+ijvPu8cZr8jlWpBPd9EG7WiinEV8HZ3P1i50cxuBrYCy+s1sLK0tbXx3QduY9++/XxuyQ08t/0XnPmW02MPK9ekFYPs6rKhC1RO5Y1Xnj2s8nx+d/dQZ28ba9KkTt71zi5+9JPNsYdSoWEnuGpWVDG+ADxiZs8Dv8q3nQq8BVg81J0Gnc/3WBeAfeWV12hvb2fSpE5++9seNm56gk8vXBBlLEFpHfAMUDUY7r7WzM4kO2t2CllkdwKPu3tfA8Y3Kntf/h+uu2E5fYcO4YcOMf/9c7nwgt+PPawK6QbDGrC8ZrSKkZTsktEDk7BmVu1P/iXPNCRNLX0eI3npFgwFI650k6FgxJTwlYEVjKhUMSRIwZCQZj2PIfWmYEhQusFI97BYolIwYir5bXczO8bMNpvZz8xsq5ndmG/XSs3NpfSppAe4yN33m9k44Mdm9r38tppWalYwoio3GJ698bU//3Fc/jWiN8M0lcRUh09w5dfG2ALsBda5+6b8pppWalYwoqr9gzrVmpoB3L3P3bvIelTnmNnZZCs1zwTOBU4g60yrSlNJVKU3NVf+3mtmtgGY7+5fzTdrpeamUP6rkslmdlz+fQcwD/i5VmpuOqW/KpkK3GVmbWR/9A+4+xozW6+VmptK
6a9KngJ+N7BdKzU3lXTPiCsYcaWbDAUjKgVDQvR5DAlTMCRIwZAQfUpcwlQxJEQHnxKWbjDSneQkKlWMqNKtGApGTDrGkDAFQ4IUDAnRVCJhCoYEjfVgdIQuFSqaSmQIYz0Yu/65IQ+TtGkfjD2CmuiUeEyN63bXEt7NpfRrifd3u88GuoD5ZnYeWsK72ZQbDM+Eut1rXsJbwYipAd3uwH+iJbybTf273YGzAg+sJbzT1pBu9/PQEt5Nxo6q/ava7sLd7tsYwRLeqhhRNazb/RngfjP7O+CnaAnvsaVKt3vNS3grGDHpvRIJUzAkSMGQEE0lEqZgSJCCISGaSiQs3WDolLgEqWLEpKlEwhQMCVIwJERTiYQpGBKkYEiIphIJUzAkSMGQkHRzoWDEle47EumObEwot0XRzN5sZo+a2ba8qfmafPuXzOy/K5bwLmy9V8WIqfxXJb3AX7n7k2Y2EXjCzNblt62sWGazkIIRVemL5e0Gduffv25m2xhGn2qIppKoSr8MwpE9m80g6zHpX8J7cb6E92ozO77o/gpGTCPodi9qas52a53At4EvuPs+4BbgDLJrZuwGvlY0tJabSq5fcR8b/u0Z3nRcJ2vuyJYwX/GPD/Poxq2MG9fGqdNO5CvLFjCpsyPySKEeTc1mNo4sFPe4+4P5ffZU3H4rsKbocVquYlw2fw63rRj4R3T+OWey5o5r+afbr2XG9MmsuucHkUY3WOmvSoysL3Wbu99csX1qxa99lLG4hPe5s89g54uvDNj2nnNnHv6+a9ZprP3hzxo9rLDyX5WcD/wJ8B/5xVMA/hpYYGZdZNfF2AF8pmhHIw6GmS109ztGev9Yvv29TfzhhW/o+20J7v5jwmWl5ssmjmYquXGoGyoPkLq7C6/x0TC33L2OtrY2PjzvnNhDydXvVcloVa0YZvbUUDcBU4a636ADJE/hOp8Prd3Mhn/dyp1f+yyWytvdqYwjoGgqmQJ8gOwSgJUM2FiXEdXBY5u3cev967n764vpOKbwEpcN1LzBWAN0uvuWwTfk13dKzpK//Qabt2zn1V//Lxd8/Etc/Wfz6b73EQ4c7GXh0lsAmD3rNL685PLII4WUg2HuhRdwG60kppLosktGD0zClmW1P/ldKxqSppZ7udpc0q0YCkZUCoaENPGrEqmrdIPRcu+VSDlUMWLSVCJh6RZsBSMqVQwJ0VQiYQqGBCkYEqKpRMLSDUa6r5ckKlWMmBKeSlQxompYU/MJZrYuX6l5nTrRklf6h4H7m5rPIls98XNmNgu4DngkX6n5kfznqhSMmEpekNfdd7v7k/n3r5OtoHgKcCnZCs0wzJWadYwRVf2OMQY1NU/JO+Fx991mdlLR/VUxoqp9KhlhU3PNVDFiGsGrkpE0NQN7zGxqXi2mkq37XpUqRlSNaWoGHiZboRm0UvOYNFRT83LgATO7Cvgl8PGiHSkYMZV8gqtKUzPAxbXsS8GIKt0znwpGVAqGhCT8XomCEZWCIUHpni1QMGLSVCJh6QYj3VomUalixKSpRMIUDAka68GYVrhuytikqUTCxnow7k33CWiYK0IX6Ev3eVHFiElTiYQpGBKkYEhIwlOJTolLkCpGVKoYElJyi2K2S1ttZnvN7OmKbTWv1KxgRFV6UzPAncD8wPaV7t6VfxUuB6FgRFV+MNz9MeCVwl8soGBEVZeKMRSt1Nw06rRSc4BWam4utf9dFjU1D3EfrdTcVBo0k2il5qZT/nkMM7sPmAucaGY7gb8B5jZspWYpQ/nBcPcFgc2317ofBSOmhN8rUTCiUjAkSMGQEE0lEqZgSJCCISGaSiQs3WDolLgEqWLEpKlEwhQMCVIwJERTiYQpGBKkYEiQgiEhCR9j6ASXBKlixGTp/l22bDD6DsEf33oqUyb2suqKXVz3nSls/q9jmXh0HwDLP7KHs07uiTzKdKeSlg3GNzYdxxknHmB/z5G/ymvf9xLzZ+2POKrB6vIp8dXAJcBedz87
33YC8E1gBtmnxC9391er7aewlpnZTDO7OF+ysXJ7qHE2CS/ua2fD85187Pd+HXso1dWh251wU3O5KzWb2efJVty7GnjazC6tuPmm4YwyhpvWTuaL817iqEHP48r1J/KhW07jprWTOdCbQhlvWFNz6Ss1fxo4x9335yv/fsvMZrj7PwxrlBE8+twETpjQx9nTeti0o+Pw9iUXv8zkzj4O9hk3rDmJ7p8cz+L3jropfJQa9hTWvFJzUTDa3H1/vsMdZjaXLBynUeVflTfaLgJYtWoVizqH+s3yPfnLDtY/O4HHnj+dnl5jf89RLH3wZL562YsAjG93Luvax+qNhQ3f9TeC8xiVz22uO+9nLZW5hy5MengQ64El7r6lYls7sBq40t3bhvEYHusCsJt2dLB64/GsumIXe19v46SJfbjDTd+fzNHtztJ5LzduMNkFYAc+EXsfG/rJH8pJFxQ+mXl1X1Nx8PksMLdipeYN7v62avsoqhifAnorN7h7L/ApM1tVNMCULH1wKq/+XxvuMPPkHm68ZE/xnequYX8w/Ss1L2eYKzVXrRgliVYxkhKqGC/9qPYnf/IfVH0yK5uagT1kTc3fAR4ATiVfqdndqx5gtex5jObQsKZm0ErNzSTdSqpgxJTwu6sKRlQKhgQpGBKiqUTCFAwJUjAkSMGQkHRzoWDElW4yFIyY9GFgCVPFkCAFQ0J0gkvCFAwJUjAkJN1cKBhxpZsMBSMqBUNC9KpEwurS1LwDeB3oA3rd/Z0j2Y+CEVXdKsaF7j6qbioFI6aEp5J038UZE+qyvqYD/2JmTwxzsd4gVYyo6tLUfL6778o72teZ2c/zSyPURMGIaQRTSdFKze6+K//vXjN7CJgD1BwMTSUtxMwmmNnE/u+B9zOMVZlDVDGiKv3gcwrwkGWVqB24193XjmRHCkZU5QbD3V8AZpexLwUjpoRfrioYUSkYEqRgSIg+JS5hqhgSkvDBZ2Muzib9BibhN7tqf246pjUkTY0IRhLMbFE9LpTaqtI9+infiN9pHIvGUjCkBgqGBI2lYOj4ogZj5uBTajOWKobUoOWDYWbzzexZM9tuZoVLPkmmpacSM2sDngPeB+wEHgcWuPszUQfWBFq9YswBtrv7C+5+ALifbH0wKdDqwTgF+FXFzzvzbVKg1YMRel+hdefOErV6MHYCb674eTqwK9JYmkqrB+Nx4K1mdrqZjQc+QbY+mBRo6c9juHuvmS0Gvg+0AavdfWvkYTWFln65KiPX6lOJjJCCIUEKhgQpGBKkYEiQgiFBCoYEKRgS9P8jBheg4QRkwwAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 108x324 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# encoder 的 hidden state 作为 key\n",
    "annotation = [3, 12, 45]\n",
    "\n",
    "# 可视化 key 向量\n",
    "plt.figure(figsize=(1.5, 4.5))\n",
    "sns.heatmap(np.transpose(np.matrix(annotation)),\n",
    "            annot=True,\n",
    "            cmap=sns.light_palette(\"orange\", as_cmap=True),\n",
    "            linewidths=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:08:22.274913Z",
     "start_time": "2020-05-05T15:08:22.268598Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "927"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算 attention score\n",
    "def single_dot_attention_score(dec_hidden_state, enc_hidden_state):\n",
    "    return np.dot(dec_hidden_state, enc_hidden_state)\n",
    "\n",
    "\n",
    "single_dot_attention_score(dec_hidden_state, enc_hidden_state=annotation)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:08:26.460795Z",
     "start_time": "2020-05-05T15:08:26.366616Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVoAAAD4CAYAAACt8i4nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAWrUlEQVR4nO3de5QU5ZnH8e8zMyADM+M4LCCCCipGjEaNF4xEYlAUrxjvNxYNKyZK1BVU1HiUZFfxRGPO7p4YJ4qCiuIqLkYT7xCMCoJKECUK8UoQUQR1GJhL8+wf3QZCYHqAfvsta36fc+pUd/X0W8+UUz+eeat6NHdHRETCKYldgIhI2iloRUQCU9CKiASmoBURCUxBKyISWFkR9qHbGkSktWyrR5hkrc+cs3zr99cKxQhaqHu3KLtJtIre2fWkovx3TbazcufB6iVx60iC8h2y6y8Xxq0jCSr7xK4gmOIErYhI0SSvmVHQiki6WPIuPSloRSRl1NGKiIRlyQva5PXYIiIpo45WRFImeR2tglZE0iWBUwcKWhFJmeTNiCpoRSRl1NGKiISlqQMRkdAUtCIigSloRUTC0tSBiEhoCloRkbDU0YqIhKagFREJTEErIhKWpg5ERELTR3BFRMJSRysiElrygjZ5PbaISMqooxWRdNHUgYhIaMn7RV1BKyLpoo5WRCQ0Ba2ISGAKWhGRwBS0IiJhaY5WRCS0wgWtmb0HfAlkgGZ3P8DMaoDJQC/gPeA0d1/R0jipDtqGhkbOPn80jY1NZDIZjjr8UC7+0dDYZRXVwF/1ptM2aykxp7QEpoz4gL8sbc91j3ejvrGEHtVN3HzSUiq2WRu71KK66rqbmD5jJp1rqnns4btilxNdJpPh5KH/Treunbn9V9fFLmfrFL6j/b67f7re8zHAs+4+zszG5J5f2dIAqQ7a9u3bMeE3N9GpYzlNTc2cNXwUA/ofwL57941dWlFNGPYhNR3XBek1v9ueKwd9wkG9VvPQa1Xc8cJ2XDpwecQKi++kEwZzzhk/4Mqf3hi7lESYeP+j7Np7R+pW1ccupQCCTx0MAQ7LPZ4ATCdP0Cbvzt4CMjM6dSwHoLm5mebmZiyBE+XF9u6n7Thw59UA9N+lnqcWVESuqPgO3H8ftq2qil1GIiz9+FOmvzCbU048MnYpBWKtXsxshJnNWW8ZscFgDjxlZq+s91o3d/8IILfumq+ivB2tme1BNsF75Ha6BHjU3Rfk/X4TIJPJcNI5P+GDD5dw1mnHs8/ee8QuqbgMht/TEzM4ff/POX3/z9m9ayPPvtWJI/ZYxRNvVvDRF+1iVykR3XBLLZdf/ENWpaKbZbOmDty9Fqht4Uv6u/sSM+sKPG1mf9mSklrsaM3sSuABsvH/MjA79/j+3NxE4pWWljL1/l/zxz/cy7z5b/H2ovdil1RU9//wAx654AN+e/bfuG92NbPfL+c/hyxl0uxqTqrdiVUNJbQv9dhlSiTTnn+Zmppq9uq7W+xSCqj1HW0+7r4kt14GPAIcBHxsZt0Bcutl+cbJ19EOB77p7k3/8G2Y/RJ4Axi3sTflWuwRALfffjsjzhqUr47gqior6HfAt3j+xTnsvluv2OUUTbfKDACdO2UYtEcd8/7WgeGHrGD80L8B8O7ydkxf2PamDiTr1T+/yXMzZjHjhTk0NDZSV7ea0dfezM0/Hx27tC1nhZkRNbNOQIm7f5l7fCTwM+BRYBjZ/BsGTM03Vr6gXQvsALy/wfbuudc2aoN23Kl7N18dQXy2YiVlZWVUVVawZk0DL856jfOHnRallhjqG421DhXbOPWNxgt/7ciF31vO8lWldO6UYa3DbTM6c8YBK2OXKpGMGnkuo0aeC8CsOfMYf+8jX++QBQp4Mawb8IhlpyLKgEnu/oSZzQYeNLPhwAfAqfkGyhe0lwLPmtlC4MPctp2A3YCRW1h80Sz79DPGXHcLmUwG
d2fwEQP4/oB+scsqmuWryrho8g4AZNbCcXt9yYDd6pkws5pJs6sBGNS3jpP3/SJmmVFcNubnvDxnLitWfs6AI0/lJz8+l1N/cGzssqQgChO07v4OsM9Gti8HDt+sitxbnp8zsxKy8xI9yH4Hi4HZ7p5pbb2xOtpEqeidXU/SXQ+clfuZW70kbh1JUJ79h5AvF8atIwkq+0AhUvLxPVt/0eHYN4tyQua968Dd1wIzi1CLiEgBJK+ZSfUHFkSkDSrQxbBCUtCKSMqooxURCUxBKyISlv5MoohIaApaEZHAFLQiImFp6kBEJDQFrYhIYApaEZGwNHUgIhKaglZEJDAFrYhIWPpbByIioamjFREJK3k5q6AVkbRJXtIqaEUkZRS0IiJh6WKYiEho6mhFRMLSJ8NEREJT0IqIBKagFREJS1MHIiKhKWhFRAJT0IqIhJXAqYPk3dkrIrJVbDOWVoxmVmpmr5nZY7nnvc1slpktNLPJZtY+3xgKWhFJmcIGLXAJsGC95zcBt7p7H2AFMDzfAApaEUkXs9YveYeynsCxwB255wYMBB7KfckE4MR842iOVkRSpqD946+AK4DK3PPOwEp3b849Xwz0yDdIcYK2ondRdvO1cJbHriA5yneIXUFyVPaJXUF6bMbFMDMbAYxYb1Otu9fmXjsOWObur5jZYV+9ZSPD5D2p1dGKSMq0PmhzoVq7iZf7AyeY2TFAB6CKbIdbbWZlua62J7Ak336KE7RLny7KbhJt+0HZdf2HcetIgo47ZtfPHBa1jEQ4Ynp2Xfdu1DISIWG/+br7VcBVALmOdrS7n21m/wucAjwADAOm5htLF8NEJF0KeDFsE64ELjOzRWTnbO/M9wZNHYhIyhS+f3T36cD03ON3gIM25/0KWhFJlwR+MkxBKyIpo6AVEQlMQSsiEpamDkREQlPQiogEpqAVEQlLUwciIqEpaEVEAlPQioiEpakDEZHQkvcnXBS0IpIu6mhFREJLXtAmr8cWEUkZdbQiki6aOhARCU1BKyISliVvRlRBKyIpo45WRCQwBa2ISFi6GCYiEpqCVkQkMAWtiEhYmjoQEQlNQSsiEpiCVkQkLE0diIiEpqAVEQlLH8EN76px9zL9pfl03q6Sx+6+BoCbbnuEaS/Op11ZKTvt8C/cOOYcqio7Rq60uD5auowrrr2JT5evoMSM004+lmFnnRS7rKLLrHVOvulTulWXcvuPa7j63pXM/6AJd+jdtYwbh25Lpw7JO1FDaWho5OzzR9PY2EQmk+Goww/l4h8NjV3WVkpeR5u6n6iTjj6YO35x0T9s63/AHjx219X87q6r6bVjV26/76lI1cVTWlrKmMt+xB+mjGfyxP9m0uSpLPrr+7HLKrqJ01ax6/br+ourT67i0au78LtrutC9ppT7ZtRHrK742rdvx4Tf3MSjD9zG/036Nc+/OIe5ry+IXdZWss1YWhjFrIOZvWxmfzazN8xsbG57bzObZWYLzWyymbXPV1HqgvbAfXZj2w261e8e2JeyslIA9t2zN0s/WRmjtKi6dunMN/v2AaCiU0d26b0TH3/yaeSqimvpigzT5zdwyiHrfj4qyrOngLuzptFjlRaNmdGpYzkAzc3NNDc3YwnsCDeLWeuXljUAA919H2BfYLCZHQzcBNzq7n2AFcDwfANtcdCa2Xlb+t6YHv79Swzot2fsMqJavGQpC95axD577RG7lKK64aEvuPwHVZRscH5ddc9K+l+1jHc+bmboYZ3iFBdRJpNhyJkXcsigMzjk4G+zz95t6+diUzyrLve0XW5xYCDwUG77BODEfGNtTUc7dlMvmNkIM5tjZnNqa2u3YheFdds9T1BaWsIJgw6MXUo0q+pXc/HosVw9+kIqKtpOqEx7fQ01lSXstVO7f3rtxqHVPH9DV3bdvozfv7I6QnVxlZaWMvX+X/PHP9zLvPlv8fai92KXtJVaP3WwflbllhH/MJJZqZnNBZYBTwN/BVa6e3PuSxYD
PfJV1OLFMDOb18J30m1T73P3WuCrhHWWPp2vjuAeeWIm01+cz923Xowl8D67Ymhqaubi0ddz/NGHc+Thh8Yup6hefaeR515fw4w3GmhocurWrGX03Su4+dztACgtMY7ZvwN3PrOKk7/Tti6UfqWqsoJ+B3yL51+cw+679YpdzpbbjLsONsiqjb2eAfY1s2rgEaDvxr4s337y3XXQDTiK7DzE+gx4Md/gSTFj1pv8dtIz3Ptfl1DeIe+8dSq5O9eMvZldeu/MeUNPiV1O0Y0aUsWoIVUAzHq7gfHPruIXw6p5f1kzO3ctw92Z9noDu3RL3Y04LfpsxUrKysqoqqxgzZoGXpz1GucPOy12WVup8I2Uu680s+nAwUC1mZXlutqewJJ878/3U/UYUOHuczd8IbfTxLls7F28PHchKz6vY8ApP+Un5x1D7X1P0djYzHmj/geAffbsxc9GnRm50uJ6Ze58pj7+DLv36c2Q0y8A4LKRP+R7h/aLXFk87nDlPStZtcZxh2/0KGPsGdvGLquoln36GWOuu4VMJoO7M/iIAXx/wNf8Z6JAv7GaWRegKRey5cARZC+ETQNOAR4AhgFT847lHvxKayKmDqLbflB2Xf9h3DqSoOOO2fUzh0UtIxGOmJ5d170btYxEqOgNhWhH545pfajtO26T+zOzb5G92FVK9nrWg+7+MzPbhWzI1gCvAee4e0NLu2lbvyeJSBtQmI7W3ecB+21k+zvAQZszloJWRNIlgRe7FbQikjIKWhGRsNTRioiEpqAVEQlMQSsiEpiCVkQkLP3hbxGR0NTRioiEpbsORERCS17QJm8yQ0QkZdTRiki6aOpARCS05P2irqAVkXRRRysiEpqCVkQkMAWtiEhYmjoQEQlNQSsiEpY6WhGR0BS0IiKBKWhFRMLS1IGISGgKWhGRwPQRXBGRsDR1ICISWvKCNnk9tohIyqijFZF0SeDUgTpaEUkZ24ylhVHMdjSzaWa2wMzeMLNLcttrzOxpM1uYW2+XtyJ336pvqRWC70BEUmPr29FFt7c+c3a7YJP7M7PuQHd3f9XMKoFXgBOBc4HP3H2cmY0BtnP3K1vajTpaEUmZwnS07v6Ru7+ae/wlsADoAQwBJuS+bALZ8G1RceZonzy4KLtJtKNmZtdfLIhbRxJU9c2udSzWHYsn+8WtIwmOmlWggQo/R2tmvYD9gFlAN3f/CLJhbGZd871fHa2IpItZqxczG2Fmc9ZbRvzzcFYBPAxc6u5fbElJuutARFKm9R2tu9cCtZscyawd2ZC9z92n5DZ/bGbdc91sd2BZvv2ooxWRdLGS1i8tDWNmwJ3AAnf/5XovPQoMyz0eBkzNV5I6WhGRjesPDAVeN7O5uW1XA+OAB81sOPABcGq+gRS0IpIyhbkY5u5/amGwwzdnLAWtiKRLAj8ZpqAVkZRR0IqIBKagFREJK8/dBDEoaEUkZdTRiogEpqAVEQkreTmroBWRtEle0ipoRSRlFLQiImHprgMRkdDU0YqIhKWP4IqIhKagFREJTEErIhJWAi+GJa8iEZGUUUcrIumii2EiIqEpaEVEAlPQioiEpakDEZHQkneNX0ErIimjjlZEJKwETh0kr8cWEUkZdbQikjLJ62hTG7SZtc7Jv/iSbtUl3H5BBWPuXcXLi5qpLM/+Rxh3dkf69kztt79RA084n04dyykpKaG0rJQpE2+JXVI0OhZfnSN1dKu29c6RDJXl2de/tudIAqcOvoZHsXUmTm9g1+1LqFuzbtsVQ8oZvF/7eEUlwITf/Ac11VWxy0iEtn4s1p0j/vdtVwzpkIJzJHkzosmrqACWrljL9DebOOU728QuRSSRsudIM6d85+seqhth1vqlSPIGrZntYWaHm1nFBtsHhytr69wwpZ7LTyinZIPjeOvjqzl+3BfcMKWexibf+JvTzIzhI6/npKGXMXnKk7GriauNH4sbpqzm8hM6bOQcWZM7R1Z/jc8R24ylOFqcOjCzi4GL
gAXAnWZ2ibtPzb18A/BE4Po227T5jdRUlrDXTmXMWtj09+2XHV9OlyqjqRmunVxP7TNrGHl0ecRKi+/+O8bRrUsNyz9byXkjr2eXXj058NvfjF1WFG35WEyb30RNpbXiHGlg5NEdIla6pQoXoGY2HjgOWObue+W21QCTgV7Ae8Bp7r6ipXHydbTnA/u7+4nAYcC1ZnbJVzW0UNwIM5tjZnNqa2vzfzcF9Oo7GZ57vZGB13/OZXevYubbTYyeuIqu25ZgZrRvZ5zUrz2vf5Apal1J0K1LDQCda6oZdFg/5r2xMHJF8bTlY/HqO80893pT7hypZ+bbzZs4R5pjl7plCjt1cDew4W/vY4Bn3b0P8GzueYvyXQwrdfc6AHd/z8wOAx4ys51pIWjdvRb4KmGdJ8fnq6NgRp1QzqgTsp3qrIVNjH+ugZv/tRPLPl9L121LcHeemddEn+6pnJ7epPrVa1i71qnoVE796jW8MHMuF/7b6bHLiqKtH4vWnyOlkSvdUoU7t919hpn12mDzELKNJ8AEYDpwZUvj5AvapWa2r7vPze20zsyOA8YDe29eyXGNnriKFXVrcWCPHmWMPb1j7JKKavnylVx0xTgAMs0Zjhs8gAGHfDtyVXHoWGzc6In1650jpYw9/Ws6tbYZMwdmNgIYsd6m2lyj2JJu7v4RgLt/ZGZd8+7HfdMT3mbWE2h296Ubea2/u7+QbweA8+TBrfiylDtqZnb9xYK4dSRBVd/sWsdi3bF4sl/cOpLgqFlQiAnWT/7U+qt4Xb6bd3+5jvax9eZoV7p79Xqvr3D37Voao8WO1t0Xt/Baa0JWRKTIgt9N8LGZdc91s92BZfne0LYmKkUk/cLfR/soMCz3eBgwtYWvBRS0IpI6hbuP1szuB14CvmFmi81sODAOGGRmC4FBuectSu1HcEWkjSrg/27c3c/cxEuHb844CloRSRn9URkRkcAUtCIigSloRUTCSl7OKmhFJG2Sl7QKWhFJlwLedVAoCloRSRl1tCIigSloRUTC0v+cUUQkNAWtiEhY6mhFREJT0IqIBKagFREJS1MHIiKhKWhFRAJT0IqIhKWP4IqIhKaOVkQkrAReDEtejy0ikjLqaEUkZZLX0SpoRSRlFLQiImHprgMRkdDU0YqIhJXAuw4UtCKSMgpaEZHAkhe05u6h9xF8ByKSGlufkquXtD5zyncoSioXI2gTwcxGuHtt7DqSQMdiHR2LdXQswknefRDhjIhdQILoWKyjY7GOjkUgbSloRUSiUNCKiATWloJWc0/r6Fiso2Oxjo5FIG3mYpiISCxtqaMVEYlCQSsiEljqg9bMBpvZW2a2yMzGxK4nJjMbb2bLzGx+7FpiMrMdzWyamS0wszfM7JLYNcViZh3M7GUz+3PuWIyNXVMapXqO1sxKgbeBQcBiYDZwpru/GbWwSMxsAFAHTHT3vWLXE4uZdQe6u/urZlYJvAKc2BZ/LszMgE7uXmdm7YA/AZe4+8zIpaVK2jvag4BF7v6OuzcCDwBDItcUjbvPAD6LXUds7v6Ru7+ae/wlsADoEbeqODyrLve0XW5Jb/cVSdqDtgfw4XrPF9NGTyjZODPrBewHzIpbSTxmVmpmc4FlwNPu3maPRShpD9qN/cEI/WstAJhZBfAwcKm7fxG7nljcPePu+wI9gYPMrM1OK4WS9qBdDOy43vOewJJItUiC5OYjHwbuc/cpsetJAndfCUwHBkcuJXXSHrSzgT5m1tvM2gNnAI9Grkkiy10AuhNY4O6/jF1PTGbWxcyqc4/LgSOAv8StKn1SHbTu3gyMBJ4ke8HjQXd/I25V8ZjZ/cBLwDfMbLGZDY9dUyT9gaHAQDObm1uOiV1UJN2BaWY2j2xj8rS7Pxa5ptRJ9e1dIiJJkOqOVkQkCRS0IiKBKWhFRAJT0IqIBKagFREJTEErIhKYglZEJLD/BysaIf1Um9rVAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# encoder 的 hidden state 矩阵，即 query 与多个 key 进行注意力运算\n",
    "annotations = np.transpose([[3, 12, 45], [59, 2, 5], [1, 43, 5], [4, 3, 45.3]])\n",
    "\n",
    "# key 矩阵\n",
    "ax = sns.heatmap(annotations,\n",
    "                 annot=True,\n",
    "                 cmap=sns.light_palette(\"orange\", as_cmap=True),\n",
    "                 linewidths=1);"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:32:21.310245Z",
     "start_time": "2020-04-23T08:32:21.304151Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([927., 397., 148., 929.])"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算 query 与多个 key 的 attention score\n",
    "def dot_attention_score(dec_hidden_state, annotations):\n",
    "    return np.matmul(np.transpose(dec_hidden_state), annotations)\n",
    "\n",
    "\n",
    "attention_scores = dot_attention_score(dec_hidden_state, annotations)\n",
    "attention_scores"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:35:37.820391Z",
     "start_time": "2020-04-23T08:35:37.813544Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([1.19202922e-001, 7.94715151e-232, 5.76614420e-340, 8.80797078e-001],\n",
       "      dtype=float128)"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 计算 attention weights\n",
    "def softmax(x):\n",
    "    x = np.array(x, dtype=np.float128)\n",
    "    e_x = np.exp(x)\n",
    "    return e_x / e_x.sum(axis=0)\n",
    "\n",
    "\n",
    "attention_weights = softmax(attention_scores)\n",
    "attention_weights\n",
    "\n",
    "# 第一个和最后一个的注意力权重分别为 0.119 和 0.880"
   ]
  },
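  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`float128` is one way to keep `np.exp` from overflowing. A portable alternative, sketched below, subtracts the largest score before exponentiating. Adding the same constant to every score leaves the softmax unchanged, and the biggest exponent becomes $e^0 = 1$, so plain `float64` suffices. `stable_softmax` is a hypothetical helper name and is not part of the original notebook. One difference from the `float128` result is that a weight as small as $e^{-781}$ underflows to exactly `0.0` in `float64`, which is harmless for the weighted sum."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def stable_softmax(x):\n",
    "    # Hypothetical helper: softmax computed in float64 with the max-shift trick\n",
    "    x = np.asarray(x, dtype=np.float64)\n",
    "    e_x = np.exp(x - x.max())  # largest entry maps to exp(0) = 1, so nothing overflows\n",
    "    return e_x / e_x.sum()\n",
    "\n",
    "\n",
    "stable_weights = stable_softmax([927., 397., 148., 929.])\n",
    "stable_weights"
   ]
  },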
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:39:55.254379Z",
     "start_time": "2020-04-23T08:39:55.248386Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([[3.57608766e-001, 4.68881939e-230, 5.76614420e-340,\n",
       "        3.52318831e+000],\n",
       "       [1.43043506e+000, 1.58943030e-231, 2.47944200e-338,\n",
       "        2.64239123e+000],\n",
       "       [5.36413149e+000, 3.97357575e-231, 2.88307210e-339,\n",
       "        3.99001076e+001]], dtype=float128)"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 从 decoder 的隐藏层序列中每个抽取多大比例\n",
    "def apply_attention_scores(attention_weights, annotations):\n",
    "    return attention_weights * annotations\n",
    "\n",
    "\n",
    "applied_attention = apply_attention_scores(attention_weights, annotations)\n",
    "applied_attention"
   ]
  },
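  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The description at the top ends by summing the weighted vectors. Summing `applied_attention` over its key axis (`np.sum(applied_attention, axis=1)`) gives that attention output, also called the context vector. The sketch below defines a hypothetical `attention_context` helper that recomputes the scores, the weights, and the weighted sum in one place. Like the cells above, it assumes the encoder hidden states serve as both keys and values. Instead of using `float128`, it subtracts the largest score before calling `np.exp`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "\n",
    "def attention_context(query, keys, values=None):\n",
    "    # Hypothetical helper. keys/values: shape (d, n), one column per encoder position;\n",
    "    # query: shape (d,). If values is None, the keys are reused as values.\n",
    "    keys = np.asarray(keys, dtype=np.float64)\n",
    "    values = keys if values is None else np.asarray(values, dtype=np.float64)\n",
    "    scores = np.asarray(query, dtype=np.float64) @ keys  # (n,) dot-product scores\n",
    "    weights = np.exp(scores - scores.max())  # shift by the max so exp() cannot overflow\n",
    "    weights /= weights.sum()  # softmax: weights sum to 1\n",
    "    return values @ weights, weights  # (d,) weighted sum of the value columns\n",
    "\n",
    "\n",
    "demo_keys = np.transpose([[3, 12, 45], [59, 2, 5], [1, 43, 5], [4, 3, 45.3]])\n",
    "demo_context, demo_weights = attention_context([5, 1, 20], demo_keys)\n",
    "demo_context"
   ]
  },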
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:40:18.289763Z",
     "start_time": "2020-04-23T08:40:18.187265Z"
    }
   },
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAVoAAAD4CAYAAACt8i4nAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAgAElEQVR4nO3dd3gU5drH8e9NgGPooCGAIl26wBEUUUEDCspBpEqxgxGOgogiWF48WFHsjUNEFLuIBTsHqYI0CyIIdkRaAAFD6Nk87x+7hERCEsLO7jL8Ptc11+7OzD5zz2Rz773PNHPOISIi3ikS7QBERPxOiVZExGNKtCIiHlOiFRHxmBKtiIjHikZgGTqsQUQKyo64hdes4Dmnjzvy5RVAJBIt7FwbkcXEtBInBh93rYtuHLEgvkrwUdviwLbY9l1044gF5RpHOwLPRCbRiohETESK1MOiRCsi/mKxt+tJiVZEfEYVrYiItyz2Em3s1dgiIj6jilZEfCb2KlolWhHxlxjsOlCiFRGfib0eUSVaEfEZVbQiIt5S14GIiNeUaEVEPKZEKyLiLXUdiIh4TYlWRMRbqmhFRLymRCsi4jElWhERb6nrQETEazoFV0TEW6poRUS8FnuJNvZqbBERn1GiFRF/MSv4kGczdpyZLTKzb81suZmNCo1/0cx+M7MloaFpfiGp60BEfCZs9eMeIMk5l25mxYC5ZvZJaNow59zkgjakRCsi/hKmnWHOOQekh14WCw2uMG2p60BEfMYKPJhZspl9mW1IztGSWZyZLQE2AtOccwtDk+4zs6Vm9piZ/SO/iFTRiojPFLyidc6lACl5TA8ATc2sHPCumTUCbgM2AMVD7x0O3J3XclTRiojPFLyiLSjn3DZgFtDBObfeBe0BXgBOz+/9SrQi4i/hO+ogIVTJYmbxQDtgpZlVDo0z4BJgWX4hqetARHwmbCcsVAYmmlkcwaJ0knPuQzObYWYJoQUtAQbk19BRn2jnzFvEfWOeJjMzkx6XXETyNX1yTH/9rfd5bdIUihQpQokS8dxz51Bq16oOwMoff+Guex8jfccOihQpwuRXxvKPfxSPwlrkLRAI0K3PABIrnsC4px7IMe3+Mc+wcPE3AOzevYc/t2zly7kfFrjt9z+axnMvvgFAyfh4/nPHEOrVrc2ePXvpe82N7N23l0BGgPbt2jD431cD8Mfa9Qwdfjd//bWdBvXr8NB9t1O8WLEwrW1kzJm3iPseeprMzAA9unQ86HPjR3v27KXvgJHs3buPQCBA+6QzGZx8aY553vlwJg899TKJCRUAuKxHB3p0bheNcAsvfEcdLAWa5TI+6XDbOqoTbSAQ4O7RT/DC2DEkJibQve9Aktq0ykqkAJ0ubEvvHhcDMH3WPB54dCzPP/MgGRkBht35AGPuuY16dWuxddtfFC0aF6U1ydtLr71NrRonk75j50HTbh92fdbzl19/h+9X/nRYbZ90YmVeef5xypYpzey5C/m/ex7hrVfGUrx4MSY+9yglS8Szb18Gfa4eROuzz6DpqQ14+PFxXHVZDzp2SGLkvY8y+d2P6dOz8xGvZ6QEAgHufuAJXvjv/s/NgIM+N35UvHgxJj5zV/BvmpFBn+Q7aX1mM5o2PiXHfBe1a8XIYf2jFGU46BTcsFq6bCXVqp5I1ZOqULxYMTq2T2L6rC9yzFOqVMms57t27cZCf4R58xdTt05N6tWtBUD5cmWJi4u9RLshdROzPl9A964d8533o09m8K8ObbNej3/xDbr1GUCnHv148tkXcn3PP5s2omyZ0gA0PbUBG1I3A2BmlCwRD0BGRgYZGQHMwDnHgsXf0L5dGwC6dGrP9Jlzj2gdIy34uanyt8/NvGiH5bmcf9NA1t/Uf8K/M+xI5VvRmlk9oDNwIsGDddcB7zvnVngcW75SN26mUmLFrNeJiSewdNnBYb365nu88Mpb7NuXwcRx
jwDw2+o1mEG/f9/Klq3buKh9Etde1StisRfU/WOeZtiQ69ixY1ee861dt4E169bT8vTgL525Xyzm99VrmPzqWJxzDLzxDhZ/9S0tTmtyyDYmv/sxrc8+sAM1EAjQtfd1rP5jLX0uvYQmjRuwZetflCldKqv6r5SYQOrGzWFY08hJ3biZSpWyf24SWPpd1D/OEREIBOh65XBWr9lAn+7tadLolIPm+d/MBSxe8j01qlbhtpuuonLiCVGI9AjE4LdHnhWtmQ0H3iCY+hcBi0PPXzezEd6HlzeXy0kalsu3VN9LL+GzD17llhuTGTv+FSD4gfvqm2WMue8OXpvwJJ/NmMv8hV97HvPhmDlnPhXKl6NRg7r5zvvR1Jm0b9cmqyqft+BL5s3/kksuvZYuvZL5ddVqVq1ec8j3L1j8DZPf+5hbbjxwvHZcXBxTJo1n9tS3WLpsJT/+/Bu4XLZ5DH6w8+J8sA6FFRcXx5RXHmb2B+NYuvxnfvxldY7p553TnBnvjeWDVx/lzNMbM3zU01GK9EjEXkWbX9dBP6CFc260c+6V0DCa4HFj/Q71puxnW6SkHPJY4CNWqWICG1I3Zr1OTd1MxYRDf/t2bH8en4V+IlaqmMDppzWhQvmyxMcfR+uzz2D5yh89i7Uwvl6yjBmzvyDpwl4MHXE3CxZ/wy2335frvB9/OoOOHQ700TvnSO7XhymTxjNl0nimffAqPbp05NU33qVzz/507tk/qxJd+eMv3DnqYZ59/F7Klyt7UNtlypTijOZN+XzeIsqXL0va9nQyMgJAsGujYsLxHqy9dyolJrBhQ/bPzdG3DkeqTOmSnHFaQz6f/02O8eXLlqZ48eCOzZ6d27F85a/RCO/IWJGCDxGS35IygSq5jK8cmpYr51yKc665c655cnLyoWY7Yo0b1mPV6rX8sXY9e/ft46OpM0g698wc86z6/UAVN+vzBVSreiIAZ7dqwQ8//cKuXbvJyAiw+KtvqV2zumexFsbNg69lzv/eYsYnb/Do6JG0bNGMh++/46D5fl21mrS07TRr0jBr3NlntuDt9z5hx85gl0Nq6ib+3LKVvr26ZCXfxIonsG59KoNuHslD995GjWpVs96/Zcs20tKCp3nv3r2HLxZ+Rc0aJ2NmnNG8GVM/mw3Aux9MJencs7zcDGGX6+emTatoh+W5LVv/Im37DiD0N120lJrVT8wxz8bNW7Oez/j8S2r9bfrRIfYq2vz6aIcA083sJ+CP0LiTgdrADV4GVhBFi8Yxcvgg+v97OIHMAN06X0idWjV44tkXaNTgFNqeexavvPke8xd+RdGiRSlTpjQP3jMcgLJlSnPVZT3oftlAzIzWZ5/Buee0jPIaFcwTz06gUYO6tA0luI8+mcFFHZJy/Pw9u1ULfvntd3pdETwqoUSJeMbcdzvHVyifo61nUl5i27Y0Rt3/OABxReN457VxbNz8JyP+bzSBzExcZiYdLjiX81oHv8SGDUnmpuH38Pgzz1O/bh16dLkoEqsdNkWLxjFyxGD6D7yVQGZm8HNTu0a0w/Lcxs1bGXH306G/qaND21acd3Zznhj3Bo3q16Jt6xa8/ObHzPh8MXFxcZQtU4oHRkb937wQYq8byHLrr8oxg1kRgl0FJxJcgzXA4tA5wAXh2Ln2iIL0hRKhymDXuujGEQviQz+StC0ObItt30U3jlhQrjGEI0t+1KDgV9jq+H1EsnK+Rx045zKBBRGIRUQkDGKvoj2qT1gQETlIBHdyFZQSrYj4jCpaERGPKdGKiHgrBk8+UaIVEZ9RohUR8ZgSrYiIt9R1ICLiNSVaERGPxV6ijb0je0VEjkT4bs54nJktMrNvzWy5mY0Kja9hZgvN7Ccze9PM8r3/lRKtiPhM2K7etQdIcs41AZoCHcysJfAg8Jhzrg6wlTwuGbufEq2I+Ex4Eq0LSg+9LBYaHJAETA6Nn0jwluN5UqIVEX85jAt/Z79JQWjIcQFtM4szsyXARmAa8Auw
zTmXEZplDcErG+ZJO8NExGcKvjPMOZcCHPI2MKHLwTY1s3LAu0D93GbLbzmqaEXEXzy4wYJzbhswC2gJlDOz/UXqSQRvWJsnJVoR8ZnwZFozSwhVsphZPNAOWAHMBLqHZrsSmJJfROo6EBGfCdtxtJWBiWYWR7AoneSc+9DMvgfeMLN7gW+A5/NrSIlWRPwlTBf+ds4tBZrlMv5Xgrf3KjAlWhHxmdg7M0yJVkT8RReVERHxmhKtiIjHlGhFRLylrgMREa8p0YqIeEyJVkTEW+o6EBHxmhKtiIjHlGhFRLylrgMREa/F3kUJI5NoS+R7AfJjR3yVaEcQO7QtDijXONoR+IcqWhERrx2riTb914gsJqaVqhl83JXvxdj9b38lq21xYFukrYxuHLGgTL1oR+AZVbQi4i/qOhAR8dqxujNMRCRSYrCijb3ULyJyRMJ2c8aqZjbTzFaY2XIzuzE0/j9mttbMloSGi/KLSBWtiPhM2CraDOBm59zXZlYa+MrMpoWmPeace7igDSnRioi/hKnrwDm3Hlgfer7dzFYAhTopQF0HIuIz4ek6yNGiWXWCd8RdGBp1g5ktNbMJZlY+v/cr0YqIzxQ80ZpZspl9mW1IPqg1s1LA28AQ51waMBaoBTQlWPE+kl9E6joQEX85jK4D51wKkHLopqwYwST7qnPundB7UrNNfw74ML/lqKIVEZ8J21EHBjwPrHDOPZptfOVss3UBluUXkSpaEfGZsB11cBZwOfCdmS0Jjbsd6G1mTQEHrAKuy68hJVoR8ZfwHXUwl9yz9seH25YSrYj4TOz1iCrRioi/xOApuEq0IuIzsZdoY6/GFhHxGVW0IuIv6joQEfGaEq2IiLcs9npElWhFxGdU0YqIeEyJVkTEW9oZJiLiNSVaERGPKdGKiHhLXQciIl5TohUR8ZgSrYiIt9R1ICLiNSVaERFvxeApuLEX0RG4bdSjnNmuF//qOSDP+ZYu/4H6LTry6WefRyiywrntrgc587wu/Kvb1YecZ+HiJXTu2Z+OXa/isn43Hlb78+Z/SdfeyXTqfg1deyczf9HXWdP6/ftWLu7Zj45dr2LkvY8SCAQA+OR/s+jY9SrqNUviu+U/FG7FYsCceYto3/kKzu/Ul5QJr0U7nIhYv2ETlw+4gwt7XE/Hnjcw8fUPcp1v4Vff0bnPEDr2vIHLkm+PcJThEJ6bM4aTryrarp3O57KeFzP8rocPOU8gEODhJ1/g7DP/GcHICqfrxR24rFcXht/5QK7T09LSGfXA44x/5kGqVE7kzy1bD6v98uXLMvaJ+0mseAI//vwb/QbeyufT3gLgiYfuolSpkjjnGHzLXXw6bTYdOyRxSu0aPPXo3dx1z6P5tB67AoEAdz/wBC/8dwyJiQl07zuApDatqF2rerRD81Rc0ThGDLmGhvVqkb5jJ92uuJmzzmhC7ZonZ82Ttj2dUQ/+l/FP/ocqlRL4c8u2KEZcWOFJoGZWFXgJqARkAinOuSfMrALwJlCd4M0Zezrn8vzn81VF2+KfjSlbtnSe87z85vu0b3sWx5cvF6GoCq/FaU0oW6bMIad/8MlnnJ90DlUqJwJwfIXyWdOmfDSN7n0H0rlnf0be80hWRZpdg3p1SKx4AgB1alVn79697N27F4BSpUoCkJERYN++jKz9C7VqVqNm9ZMPautosnTZSqpVrULVk6pQvFgxOrZPYvqsedEOy3MVT6hAw3q1AChVsgQ1q59E6qYtOeb54NM5nH/emVSplADA8RVi///kIGYFH/KWAdzsnKsPtASuN7MGwAhgunOuDjA99DpPhU60Znbo37MxKnXjZj6b+QW9ul0U7VDCYtXva0hL287l/YbQtXcy730wFYBffv2dT6bO5PUXn2LKpPEUKVKEDz7+LM+2pn42h/r1alO8ePGscf0GDqNVUhdKloinfbs2nq5LJKVu3EylShWzXicmJpC6cXMUI4q8NetSWfHDrzRpeEqO8atWryMtLZ3L
r7uDrpcP5b2PZkQpwuhzzq13zn0der4dWAGcCHQGJoZmmwhckl9bR9J1MAp4IbcJZpYMJAOMGzeO5D7tjmAx4XPfw+O4ZfA1xMXFRTuUsAgEAixf8SMvpjzC7t176XXF9TQ5tQHzF33NshU/0r1vsK969569Oardv/vp5994+IkUJox9KMf458eOYc+evdxy+70sWPQNZ53Z3NP1iRTn3EHjLAYPCfLKjp27GDz8QW4f2p9SpUrkmBYIBFi+8hdefPYedu/ZS69rbqVJo7rUqHZilKItjIL/LbPnqpAU51xKLvNVB5oBC4FE59x6CCZjM6v49/n/Ls9Ea2ZLDzUJSDzU+0KB7g/Wkf5rfnFExLIVPzH0ttEAbN2Wxux5iykaF0e781pFObLCqZSYQPlyZSkRH0+J+Hian3YqK3/4BeccXTq15+bB1+aYf9qMz3n6v8Ev4nvvGkbjhnXZkLqJG4aO5MF7RnBy1YP/mf7xj+IktWnF9FnzfJNoKyUmsGHDxqzXqambqJhwfBQjipx9GRkMHj6aTh3acEHSmQdNr1TxeMqXK0OJ+OMoEX8czZs1ZOVPq46uRHsYRx38LVfl3pxZKeBtYIhzLq0wX8r5RZQIXAF0ymX487CXFmUzPniRGR9OZMaHE2nf9mzuGnH9UZtkAdqeexZffrOUjIwAu3btZul3K6hVsxpnnv5Ppk6bnbVzbNtfaaxdt4Hzk85hyqTxTJk0nsYN65KWlk7yoBEMHdyf05o1zmp3x85dbNwU/PNmZASYPXchNWsc3f2y2TVuWI9Vq9fyx9r17N23j4+mziCpzdH7OSgo5xx33PMUNatX5eq+nXOdp22bM/jym++Dn6nde1i67EdqVT8pwpEeqfAddWBmxQgm2Vedc++ERqeaWeXQ9MrAxkO9f7/8ug4+BEo555bkEsCsfKOMsKG3j2bRl0vZui2N1hdexqDrLicjIwOA3t07Rjm6wzd0xD0s+nIJW7f9ResLejBo4FVkZAR3avXucTG1albjnFanc3HPfhQxo3uXjpxSuwYAQ264hmsGDCPTOYoVjWPkbUM4sUqlHO2/8ua7rF69jmdTXubZlJcBmPDfMTjnGHjjHezdt4/MQICWp/+TXt0vBoJV8T2jn2TL1r+4btBt1K9bi+fHjongVjlyRYvGMXLEYPoPvJVAZibdOl9IndB287Ovvl3BlI9ncUrtanTuMwSAoddfxroNmwDo3e1CatWoyjmtmnFxn8EUsSJ073w+p9SuFs2wD1+YuoEsWLo+D6xwzmU/zOZ94EpgdOhxSr5t5dZfFWYx03UQVaVqBh93rYtuHLEgvkrwUdviwLZIWxndOGJBmXoQjmOzlowoeFJrOvqQyzOzs4HPge8IHt4FcDvBftpJwMnAaqCHc25Lro2E+Oo4WhGRcB1H65ybm0djbQ+nLSVaEfGXGDyCRIlWRHxGiVZExFuqaEVEvKZEKyLiMSVaERGPKdGKiHgrBi/8rUQrIj6jilZExFs66kBExGuxl2hjrzNDRMRnVNGKiL+o60BExGux90NdiVZE/EUVrYiI15RoRUQ8pkQrIuItdR2IiHgt9hJt7O2eExE5EmYFH/JtyiaY2UYzW5Zt3H/MbK2ZLQkNF+XXjhKtiPhM+G43DrwIdMhl/GPOuaah4eP8GlGiFRGfCV+idc7NAfK8w21BKNGKiL8cRteBmSWb2ZfZhuQCLuUGM1sa6loon9/MSrQi4jMFr2idcynOuebZhpQCLGAsUAtoCqwHHsnvDTrqQER8xtv60TmXuv+5mT0HfBjdiEREIi2MRx3k3rxVzvayC7DsUPPup4pWRHwmfMfRmtnrwLnACWa2BrgLONfMmgIOWAVcl187SrQiIofgnOudy+jnD7cdJVoR8Redgisi4rVjNdGWqhmRxRwV4qtEO4LYoW1xQJl60Y7AP3S7cRERrx2rFe3mBRFZTEw7oWXwcde66MYRC/ZXstoWB7bFa7GXHCKujwtTQ7G3
LVXRioi/aGeYiIjXlGhFRLwVgzvDYi8iERGfUUUrIj6jrgMREW9pZ5iIiNeUaEVEPKZEKyLirRg86kCJVkR8RhWtiIjHlGhFRLwVe3lWiVZE/Cb2Mm3s9RqLiByRgt9uPN+WzCaY2UYzW5ZtXAUzm2ZmP4Uey+fXjhKtiPiLFSn4kL8XgQ5/GzcCmO6cqwNMD73OkxKtiPhM+Cpa59wcYMvfRncGJoaeTwQuya8dJVoR8RezAg9mlmxmX2YbkguwhETn3HqA0GPF/N6gnWEi4jMF3xnmnEsBUryLJUgVrYj4TPi6Dg4h1cwqA4QeN+b3BiVaEfGX8O4My837wJWh51cCU/J7gxKtiMghmNnrwHygrpmtMbN+wGjgfDP7CTg/9DpP6qMVEX8J4/VonXO9DzGp7eG0o0QrIj4Te2eGKdGKiM8o0YqIeEu3shER8Vrs7eNXohURn1FFKyLirRjsOoi9GltExGdU0YqIz6ii9VxSt5vpdPkddL7y/+h6zV2HnG/pil+pf85VfDpzcQSjC59AIMAll17LdYNuK/B71m/YyOX9b+LCLlfSsetVTHx1cta0x5+ZQKce/ejcsz/XDBhG6sbNAPzy22ouveJ6GrW4gOcnvhn29YiWOfMW0b7zFZzfqS8pE16LdjhREciES8adzHWvVQHgj61F6TG+Khc8VZ0hkyuzNxDlAAvrMK7eFSm+rGgnPjWCCuVKH3J6IJDJw89O4uzTG0cwqvB66bW3qVXjZNJ37Czwe+Li4hhx80Aa1j+F9B076db7Os5q2ZzatarT/8pLGXL9NVltP5PyEnffOZRyZUtzx62DmD5zrlerEnGBQIC7H3iCF/47hsTEBLr3HUBSm1bUrlU92qFF1EsLy1HrhL2k7wnWWw9/lsBVLbfRsdF2Rn5Ykclfl6VPi7+iHGVhxF79GHsRRcDLk6fR/tzmHF++TLRDKZQNqZuY9fkCunftmDVu2fc/cFm/G+naO5l+A4excdOfB72vYsLxNKx/CgClSpagZs2TsyrXUqVKZs23a9duLPRtf3yF8pzaqB5Fi/rnO3npspVUq1qFqidVoXixYnRsn8T0WfOiHVZEbUgryqyfStH9n8FE6hws+K0E7RtsB6BLkzSm/1AqmiEWXgxWtPkmWjOrZ2ZtzazU38b//fYOscGg301j6HrNSN6cMvOgyambtvDZnK/odUlSFIILj/vHPM2wIddRJHT1oX37Mrh39FM8OWYU77yeQrdLLuSxp8fn2caatRtYsfJnmjSunzXusafG06Z9Tz74+DNuHHi1p+sQTakbN1Op0oFrNScmJmR94Rwr7v80gWHtNlEklGu27ipCmeMCFA1lhEplMkhNO1q/XD2/TOJhyzPRmtlggpcAGwQsM7PO2Sbf72VghfX62Dt594W7ee6RW3j1neksXrIyx/T7nniNWwb2JC7u6CzmZ86ZT4Xy5WjUoG7WuN9+/4Mff/mNqwfcQuee/Rn73Cukph46cezYuYvBt4zk9mHX56hkbxrUn9lTJ9Hpona88sa7nq5HNDnnDhpnMXhIkFdm/liSCiUDNKqy58BId/D6H72bJPYSbX5fWdcCpznn0s2sOjDZzKo7554gjyhDt4NIBhg3bhzJXU8NU7j5S0wI3pDy+PJlOL/1aSz9/ldaNK2XNX3Zyt8YetdYALb+tZ3Z87+laFwR2rU+LWIxHomvlyxjxuwvmDN3IXv27iV9x06eGvsCdWpV582Xnskx7/oNGxkw+HYAevW4mN49LmbfvgwG3zySThe144K2rXNdxr8ubMt1g25j8L/9WdVWSkxgw4YD12pOTd1ExYTjoxhRZH29Op4ZP5Rkzk812JNhpO8pwn1TE0jbHUdGJhQtEuxaqFg6I9qhFk4MfkPkl2jjnHPpAM65VWZ2LsFkW408Eu3fbg/h2LwgHLHma+euPWRmZlKqZDw7d+1h3qJl/PvqzjnmmTH5kaznI+59jnPPanrUJFmAmwdfy82DrwVg
4eIlTHjpTR4Z/X907HoV33y7nGZNGrJvXwarfv+DOrVrMGXSgS4E5xx3jHqImjWqcfXlPXO0u+r3NVSvdhIAM2Z/Qc0aJ0dupSKsccN6rFq9lj/Wriex4gl8NHUGj9x/Z7TDipib223m5nbBXzwLV8Uz4YvyPNJ1A4PfqszU70vTsdF23v22DEl106McaWHF3q/V/BLtBjNr6pxbAhCqbP8FTABibpf9n1v+4vrbnwQgkBHgXxecSeuWp/L6uzMA6N3l6O2XzUvxYsV4cswo7n3oKbanpxPICHBl3+7UqV0jx3xfLVnGlA+ncUqdmnTu2R+AoYP60+acljzyZAq/rfoDK1KEEysnMuqOmwDYtHkL3fpcR/qOnRQxY+Krk/n4nRdzdDkcbYoWjWPkiMH0H3grgcxMunW+8KBtdSwa1m4zN02uzOMzjqd+5T30aJYW7ZAKJ/YKWiy3/qqsiWYnARnOuQ25TDvLOVeQXbURq2hj2gktg4+71kU3jlgQHzxuU9uCA9vitRjMDpHWx0E40uSmuYdOan+XcHZENnyeFa1zbk0e046t42FE5CgRe19aR+vxGyIiuQvjzjAzWwVsBwIEf903L0w7SrQi4jNhr2jPc84d0YHWSrQi4i+Fv424Z2IvIhGRIxLWExYc8D8z+yp0fkChqKIVEZ8peNdB9pOrQlJC5wHsd5Zzbp2ZVQSmmdlK59ycw41IiVZEfKbgifZvJ1flNn1d6HGjmb0LnA4cdqJV14GI+EuYeg7MrKSZld7/HLgAWFaYkFTRiojPhO2og0Tg3dAFh4oCrznnPi1MQ0q0IuIvYTrqwDn3K9AkHG0p0YqIz+jMMBERjynRioh46yi8Hq2IyFFGiVZExFuqaEVEvKZEKyLiMSVaERFvqetARMRrSrQiIh5TohUR8VYMXvhbiVZEfEYVrYiIt2JwZ1js1dgiIj6jilZEfCb2KlolWhHxGSVaERFv6agDERGvqaIVEfGWjjoQEfFamG6DC5hZBzP7wcx+NrMRhY1IiVZEfCY8idbM4oBngAuBBkBvM2tQmIgi03VwQsuILOaoEF8l2hHEDm2LA/q4aEfgH+HrOjgd+Dl0N1zM7A2gM/D94TYUiUQbEx0mZpbsnEuJdhyxQBbth2sAAAJkSURBVNviAG2LA3yzLeKrFDjnmFkykJxtVEq2bXAi8Ee2aWuAMwoT0rHUdZCc/yzHDG2LA7QtDjjmtoVzLsU51zzbkP2LJreEXaifHsdSohURORxrgKrZXp8ErCtMQ0q0IiK5WwzUMbMaZlYc6AW8X5iGjqXjaI/+vqfw0bY4QNviAG2LbJxzGWZ2AzAViAMmOOeWF6Ytc057O0VEvKSuAxERjynRioh4zPeJNlyn0PmBmU0ws41mtizasUSTmVU1s5lmtsLMlpvZjdGOKVrM7DgzW2Rm34a2xahox+RHvu6jDZ1C9yNwPsFDNRYDvZ1zh31mhx+YWWsgHXjJOdco2vFEi5lVBio75742s9LAV8Alx+LnwswMKOmcSzezYsBc4Ebn3IIoh+Yrfq9os06hc87tBfafQndMcs7NAbZEO45oc86td859HXq+HVhB8CygY44LSg+9LBYa/Ft9RYnfE21up9Adk/9Qkjszqw40AxZGN5LoMbM4M1sCbASmOeeO2W3hFb8n2rCdQif+Y2algLeBIc65tGjHEy3OuYBzrinBM59ON7NjtlvJK35PtGE7hU78JdQf+TbwqnPunWjHEwucc9uAWUCHKIfiO35PtGE7hU78I7QD6HlghXPu0WjHE01mlmBm5ULP44F2wMroRuU/vk60zrkMYP8pdCuASYU9hc4PzOx1YD5Q18zWmFm/aMcUJWcBlwNJZrYkNFwU7aCipDIw08yWEixMpjnnPoxyTL7j68O7RERiga8rWhGRWKBEKyLiMSVaERGPKdGKiHhMiVZExGNKtCIiHlOiFRHx2P8DIAoAnjMwDk4AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 432x288 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "ax = sns.heatmap(applied_attention,\n",
    "                 annot=True,\n",
    "                 cmap=sns.light_palette(\"orange\", as_cmap=True),\n",
    "                 linewidths=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:43:05.003499Z",
     "start_time": "2020-04-23T08:43:04.997068Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([ 3.88079708,  4.0728263 , 45.26423912], dtype=float128)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Compute the attention output by summing the weighted vectors in applied_attention along axis 1\n",
    "def calculate_attention_vector(applied_attention):\n",
    "    return np.sum(applied_attention, axis=1)\n",
    "\n",
    "attention_vector = calculate_attention_vector(applied_attention)\n",
    "attention_vector"
   ]
  },
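  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The weighting and summation above can be checked in isolation. This is a minimal NumPy sketch. It assumes that each column of a hypothetical `annotations` matrix is one encoder vector and that `weights` is the softmax of hypothetical `scores`; that layout is a guess and is not taken from the earlier cells. Scaling each column by its weight and summing along `axis=1`, as `calculate_attention_vector` does, gives the same vector as the matrix-vector product `annotations @ weights`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Hypothetical inputs: 3-dimensional annotation vectors for 4 encoder positions\n",
    "# (one per column) and one raw attention score per position\n",
    "annotations = np.array([[0.0, 1.0, 2.0, 3.0],\n",
    "                        [1.0, 0.0, 1.0, 0.0],\n",
    "                        [2.0, 2.0, 0.0, 1.0]])\n",
    "scores = np.array([1.0, 2.0, 0.5, -1.0])\n",
    "\n",
    "# Softmax turns the scores into weights that sum to 1\n",
    "weights = np.exp(scores - scores.max())\n",
    "weights /= weights.sum()\n",
    "\n",
    "# Scale each column by its weight (broadcasting over rows), then sum along axis 1\n",
    "applied = annotations * weights  # shape (3, 4)\n",
    "context = np.sum(applied, axis=1)  # shape (3,)\n",
    "\n",
    "# Same result as a single matrix-vector product\n",
    "print(np.allclose(context, annotations @ weights))\n",
    "print(context)"
   ]
  },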
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-04-23T08:43:23.615316Z",
     "start_time": "2020-04-23T08:43:23.530574Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<matplotlib.axes._subplots.AxesSubplot at 0x7fb88d8a3bd0>"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAIYAAAEWCAYAAACjaO9mAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4xLjMsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+AADFEAAAPVUlEQVR4nO3dfZBU1ZnH8e8zA6yUYAyLCygxptgUhHUjlkq0qCiBXYOBKO7qJuTdSg1uajWk1hdEa9ekShLNC5gqqzaiQUmKBGOIlYTaNUuhgME3VgREQGMAE0RCNhE3vjGZ4dk/7kV6mjNzp3vm3nO75/ep6mL6dvftU8NvnnPv7T7nmLsjUq0ldgOknBQMCVIwJEjBkCAFQ4IUDAlSMJqQmbWa2dNmtiq9f6+Z7TazzeltUtY+BuXfTIlgHrADOKFi23Xu/uPe7qCIiuG6vX3rwgyv9Zb1yzazscBM4O6s5/akkIqxb18R71JuJ59c2FvdDlwPDK/avtDM/h1YA9zg7od62omOMSIyq+dmc83sfypuc4/uz2YBB9z9qaq3WgBMAM4BRgDzM9tWwGclrorxdsWwym2trdldQ7XOzq77qGRmXwM+DXQAx5EcY/zE3T9V8ZypwLXuPqun91HFiKieitETd1/g7mPd/TTg48BD7v4pMxuTvJ8ZMBvYltU2nZVE1FLcn+VyMzuJpGJtBv456wXqSgoS6kqGDKm9K2lv774r6U+qGBFldQ0xKRgRKRgSpGBIkIIhQQqGBCkYEqRgSJCCIUFlDoY+K5EgVYyIylwxFIyIFAwJUjAkSMGQIAVDghQMCVIwJEjBkCAFQ4IUDAkqczD0WUlE/T2u5Oh+jxnt/h4ze8LMfmVm95nZkKx9KBgR5RUMjo52P+I2YLG7vxd4Bfh81g4UjIjyCEb1aPd09Nk04MgUCMtIRqP1SMGIqL8HNaeOjHY/nN7/S+Cgu3ek9/cCp2S1TQefEdVz8OnuS4Al4f0dHe2eDl4GgiPXMkfAKRgR5XBWMgW42Mw+wtHR7rcDJ5rZoLRqjAUyB42qK2ki3Yx2/yTwMHBZ+rTPAj/N2peCEVGOZyXV5gP/amYvkBxzfDfrBU3TlbS3H2LevE/S3t5OZ2cnF1zwYa644otdnrN//0t8/es38uqrf2T48BO56aZvcNJJoyO1ON8LXO6+Flib/rwLmFzL65smGIMHD2HRomUMHXo8HR1/5uqrP8EHPnA+EycenbnwO9+5jQsvnM2MGZeyadNj3HXXt7jxxm9Ea7OufBbAzBg69HgAOjo66OzsoPqAfM+eX3PWWecBcOaZ57Jhw5qim9lFgV1JzTIrhplNAC4hOfd1kiPan7n7jh5fGEFnZydXXvkPvPTSb5g9+xNMnHhGl8fHjZvAunW/4LLLPssjj6zmjTde59VXX+Ed73hnlPY2bMUws/nACpI/vSeBjenPPzSzG/JvXm1aW1u5++6fcv/969i5cyu7dz/f5fEvfOF6tm7dSFvbbLZseZKRI0fR2hqvNy1zxehxqiUzex74G3f/c9X2IcCz6bX30OvmAnMB7rzzzrNmzaq+OJe/Zcvu4LjjhvKxj4U/Fnjzzdf5zGcu4v771xfSntBUS+PH1z7V0nPPlWOqpcPAycCLVdvHcPSS6zGqrs4VMgfXwYN/ZNCgQQwbdgKHDr3FU089ypw5bV2ec+RspKWlheXLl3DRRf+Yf8N6UOauJCsYXwLWmNmvgN+m204F/hq4Ks+G1eoPfzjArbfewOHDnRw+7EydOoPzzvsQS5d+m/HjT2fKlOls3vwkd921CDPj/e8/m3nzbo7a5jIHI3PWPjNrITkHPoWkFO4FNrp7Zy/fQ7P2Ee5KJk6svSvZvr0cXQnufhh4vIC2SIk0zQWuRlTmrkTBiKjAmYFrpmBEpIohQQqGBCkYEqRgSFCZg1Hi42KJSRUjojJXDAUjIgVDgsocDB1j
RNTfX9Qxs+PM7Ekz22Jmz5rZV9Lt92oJ7waSQ8U4BExz99fMbDDwSzP7r/SxmpbwVjAi6u9gePIditfSu4PTW12rIaoriSin0e6tZrYZOACsdvcn0ocWmtlWM1tsZn+RtR8FI6I8Rru7e6e7TyIZozrZzE6njiW81ZVE1N+j3aued9DM1gIz3P2b6eZDZnYPcG3W61UxmoiZnWRmJ6Y/DwX+DtipJbwbTA5nJWOAZWbWSvJH/yN3X2VmD9W6hLeCEVEOZyVbgTMD26fVui8FI6IyX/lUMCJSMCRIwZAgfUtcglQxJEjBkKAyB6PEvZzEpIoRUZkrhoIRkYIhQQqGBCkYEqRgSJCCIUEDPhjpxGRSZcAHQ8IGfDA0nWPjVU1VjIgGfMWQMAVDgsocDH26GlGBo921hHcjyWHs6pHR7mcAk4AZZnYuWsK7sfR3MDwRGu2uJbwbSUtL7besQc3Vo92BX6MlvBtLHoOa0+VCJqVjWB8A3hd6Wtb7qGI0KXc/SLLu6rmkS3inD2kJ77LL4awkNNp9B3Us4a2uJKICR7tvB1aY2S3A0wykJbwbUYGj3QfuEt6NqMxXPhWMiBQMCVIwJEjBkKAyB0PXMSRIFSOiMlcMBSMiBUOCFAwJUjAkSMGQIAVDghQMCVIwJKjMwdCVTwlSxYhIU0ZLUJm7EgUjIgVDgsocjBL3cs0vh+ED7zKzh81sRzqoeV66/ctm9lLFEt4fyWqbKkZEOVSMDuAad99kZsOBp8xsdfrY4oplNjM1XcXo7OykrW02CxZcecxjW7ZsZO7cS5k+fSLr1j0YoXVd5TCo+WV335T+/CeSwUaZ41RDmi4YK1d+j1NPHRd8bNSoMcyf/zWmT59VcKvCcpgGoWLfdhrJGJMjS3hflS7hvdTM3pn1+qYKxu9/v5/HH1/LzJmXBR8fPXos48ZNoKUkFxDqCUbWaPdkvzYMWAl8yd3/D/gPYBzJnBkvA9/KalvdxxhmdoW731Pv6/Nwxx1f5corr+PNN1+P3ZTcZI12N7PBJKFY7u4/SV/zu4rH7wJWZb1PX/50vtJD495O9ZIlmcuQ94vHHnuYE08cwfjxpxfyfv0hh7MSIxmXusPdF1VsH1PxtEvp6xLeZra1u4eAUd29rirVXsQ8n9u2beLRRx/iiSfW095+iDfeeI2FC6/lppt6fSBeuBzOSqYAnwaeSSdPAbgRmGNmk0jmxdgDHHtkXiWrKxkFfJhk3qZKBjxaQ4Nz19Z2DW1t1wCwefMT3Hff0lKHAnIZ1PxLkv+bav9Z676yupJVwDB3f7HqtodkUo7SW7r022zYsAaAnTu3cvnl57Nu3YMsWnQzn/vczKhty/OspM9tc8+cdaevCulKyi6dMrrLf+38+dlTHlW77bZgReh3uvIZUZk/K1EwIlIwJEjBkKAyB6Mc14aldFQxIipzxVAwIirJZ3lBCkZEqhgSpGBIkIIhQQqGBCkYElTmYJT4hEliUsWIqMwVQ8GISMGQIAVDghQMCSpzMHRWElGBo91HmNnqdAnv1QNuiGKjyeFb4kdGu7+PZFnNfzGzicANwJp0Ce816f0eKRgRFTja/RKSpbtBS3gPbFWj3Ue5+8uQhAf4q6zXKxgRFTjavWY6K4monrOSeka7A78zszHu/nI6wPlA1vuoYkRU1Gh34GckS3eDlvAuvwJHu98K/MjMPg/8Brg8a0cKRkQFjnYHmF7LvhSMiPQtcQkq8yVxBSOiMgejxMVMYlLFiKjMFUPBiEjBkKABH4x0/impMuCDIWEDPhhl/gUUJTQ5Ypl/L6oYESkYEqRgSJCCIUFlDoYuiUuQKkZEZa4YCkZECoYEKRgSpGBIkIIhQWUOhk5XI2ppqf2WJV1w94CZbavY9uVa13ZXMJrPvcCMwPbF7j4pvWUunqeuJKI8uhJ3X58OaO4TVYyI8hrU3I2Bu7Z7o6knGO6+xN3Prrj1Zins4tZ2l74r6qyk6LXdpY9ymGqpm/fp57Xd
JV95VAwz+yEwFRhpZnuBm4Gp/b22u+Qop7OSOYHN3611PwpGRGW+8qlgRKRgSFCZg6GzEglSxYiozBVDwYhIwZAgBUOCFAwJUjAkSMGQIAVDgsocDF3gkiBVjIjKPGV0iZvWNy0tsGkT/Pznyf177oFdu+Dpp5PbGWfEbR8U90WdejRtxZg3D3bsgBNOOLrtuutg5cp4baqmY4yCnXIKzJwJd98duyU9K3PFyAyGmU0ws+npOluV20ODWkrh9tvh+uvh8OGu2xcuhC1bYNEiGDIkTtsqNWwwzOyLJMskXQ1sM7NLKh7+ap4Nq9fMmXDgQHJ8UWnBApgwAc45B0aMgPnz47SvUsMGA2gDznL32SRfMP23I6v/Qrcr6XQZFLNkSW+GPfSfKVPg4oth925YsQKmTYPvfx/2708eb29PDkQnTy60WUFlDoZ5aGbStxtu2919YsX9YcCPge3ANHef1Iv38FgHWRdcANdeCx/9KIwefTQcixfDW28lVaQo6a+5y29i/Xq6/+V34/zzu/+DBDCzpcAs4IC7n55uGwHcB5xG8i3xf3L3V3raT1bF2J9+7RwAd38tfdORwN9mvLZUli+HrVvhmWdg5Ei45ZbYLcqtYtzLsYOaa17CO6tijAU63H1/4LEp7r6hFw2NVjHKJFQxHnmk9orxwQ/2XDEA0kHNqyoqxnPA1Ip1V9e6+/ie9tHjdQx339vDY70JhfSgwD+YLkt4m5mW8C6zerqSPox2r0nTXvlsBPVUjKwlvLuhJbwbSYGnq1rCu5EUOKhZS3g3kgIHNYOW8G4cZT6NVzAiUjAkqMzB0FmJBKliRFTmiqFgRFTmLwMrGBGpYkiQgiFBCoYEKRgSVOZglPi4WGJSxYiozBVDwYhIwZAgBUOCFAwJUjAkSMGQIAVDghQMCVIwpDBmtgf4E9BJMu747Hr2o2BElGPF+JC7/29fdqBgRFTmrkQfokWU06BmB/7bzJ7qy4DnHufH6CeaH4Pw/Bgvvlj7/BjvfnfmjDonu/u+dKqD1cDV7r6+1vdRxYgoj0HN7r4v/fcA8ABQ12xjCkZELS2133piZseb2fAjPwMX0ovlukN08BlRDl3sKOABS3Y8CPiBuz9Yz44UjIj6Oxjuvgvol1nSCwlG/se30t+KOCspBTObm05TJL0wkA4+c5nErFkNpGBIDRQMCRpIwdDxRQ0GzMGn1GYgVQypQdMHw8xmmNlzZvaCmWXOui+Jpu5KzKwVeB74e2AvsBGY4+7bozasATR7xZgMvODuu9y9HVgBXJLxGqH5g3EK8NuK+3vTbZKh2YMR+piqefvOftTswdgLvKvi/lhgX6S2NJRmD8ZG4L1m9h4zGwJ8nGSJBsnQ1N/HcPcOM7sK+AXQCix192cjN6shNPXpqtSv2bsSqZOCIUEKhgQpGBKkYEiQgiFBCoYEKRgS9P9ZjbsFyVcUZgAAAABJRU5ErkJggg==\n",
      "text/plain": [
       "<Figure size 108x324 with 2 Axes>"
      ]
     },
     "metadata": {
      "needs_background": "light"
     },
     "output_type": "display_data"
    }
   ],
   "source": [
    "# Visualize the attention output as a column vector\n",
    "plt.figure(figsize=(1.5, 4.5))\n",
    "sns.heatmap(attention_vector.reshape(-1, 1),\n",
    "            annot=True,\n",
    "            cmap=sns.light_palette(\"blue\", as_cmap=True),\n",
    "            linewidths=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:41:53.807016Z",
     "start_time": "2020-05-05T15:41:53.801668Z"
    }
   },
   "source": [
    "## `PyTorch` implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:44:21.168383Z",
     "start_time": "2020-05-05T15:44:21.161619Z"
    }
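   },
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "\n",
    "def attention(query, key, value, mask=None, dropout=None):\n",
    "    \"\"\"\n",
    "    Scaled dot-product attention.\n",
    "\n",
    "    query : batch, target_len, feats\n",
    "    key   : batch, seq_len,    feats\n",
    "    value : batch, seq_len,    val_feats\n",
    "    mask  : broadcastable to (batch, target_len, seq_len); positions where mask == 0 are masked out\n",
    "\n",
    "    return: output (batch, target_len, val_feats) and attention weights (batch, target_len, seq_len)\n",
    "    \"\"\"\n",
    "    d_k = query.size(-1)\n",
    "    # Dot product of every query with every key, scaled by sqrt(d_k) so large scores do not saturate the softmax\n",
    "    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "\n",
    "    if mask is not None:\n",
    "        # A large negative score at masked positions gives them ~0 weight after softmax\n",
    "        scores = scores.masked_fill(mask == 0, -1e9)\n",
    "    # Normalize over the key axis: each query's weights sum to 1\n",
    "    p_attn = F.softmax(scores, dim=-1)\n",
    "\n",
    "    if dropout is not None:\n",
    "        p_attn = dropout(p_attn)\n",
    "    return torch.matmul(p_attn, value), p_attn"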
   ]
  },
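  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In `attention`, positions where `mask == 0` are excluded from the softmax. This is a minimal sketch of a causal (look-ahead) mask under that convention. It repeats only the score, mask and softmax steps so it runs on its own. The sizes and random inputs are illustrative assumptions. Passing the same `causal_mask` as the `mask` argument of `attention` should apply identical masking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "seq_len, d_k = 4, 8  # illustrative sizes\n",
    "query = torch.randn(1, seq_len, d_k)\n",
    "key = torch.randn(1, seq_len, d_k)\n",
    "\n",
    "# Lower-triangular mask: position i may attend to positions 0..i (1 = keep, 0 = mask)\n",
    "causal_mask = torch.tril(torch.ones(seq_len, seq_len)).unsqueeze(0)  # (1, seq_len, seq_len)\n",
    "\n",
    "scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)\n",
    "scores = scores.masked_fill(causal_mask == 0, -1e9)\n",
    "p_attn = F.softmax(scores, dim=-1)\n",
    "\n",
    "# Future positions receive zero weight; each row still sums to 1\n",
    "print(p_attn[0])"
   ]
  },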
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2020-05-05T15:44:22.706233Z",
     "start_time": "2020-05-05T15:44:22.698524Z"
    }
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([3, 5, 8])\n",
      "Test passed\n"
     ]
    }
   ],
   "source": [
    "def test_attention():\n",
    "    query = torch.randn(3, 5, 4)  # batch, target_len, feats\n",
    "    key = torch.randn(3, 6, 4)  # batch, seq_len, feats\n",
    "    value = torch.randn(3, 6, 8)  # batch, seq_len, val_feats\n",
    "    attn, _ = attention(query, key, value)\n",
    "    print(attn.shape)\n",
    "    assert attn.shape == (3, 5, 8)\n",
    "    print(\"Test passed\")\n",
    "\n",
    "\n",
    "test_attention()"
   ]
  },
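  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "PyTorch 2.0 and later also provide a fused `torch.nn.functional.scaled_dot_product_attention`. This sketch assumes such a version is installed. It checks the hand-written computation against the built-in function, using the shapes from `test_attention` and a hypothetical padding mask. In the built-in function, a boolean `attn_mask` marks with `True` the positions that may be attended. This matches `mask != 0` in the convention used here."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "torch.manual_seed(0)\n",
    "query = torch.randn(3, 5, 4)  # batch, target_len, feats\n",
    "key = torch.randn(3, 6, 4)    # batch, seq_len, feats\n",
    "value = torch.randn(3, 6, 8)  # batch, seq_len, val_feats\n",
    "\n",
    "# Hypothetical padding mask: the last two key positions are padding (0 = masked)\n",
    "mask = torch.ones(3, 1, 6)\n",
    "mask[:, :, -2:] = 0\n",
    "\n",
    "# Hand-written scaled dot-product attention (same steps as the attention function)\n",
    "scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(query.size(-1))\n",
    "scores = scores.masked_fill(mask == 0, -1e9)\n",
    "manual = torch.matmul(F.softmax(scores, dim=-1), value)\n",
    "\n",
    "# Built-in version (PyTorch >= 2.0); boolean attn_mask uses True = may attend\n",
    "builtin = F.scaled_dot_product_attention(query, key, value, attn_mask=mask.bool())\n",
    "\n",
    "print(torch.allclose(manual, builtin, atol=1e-5))"
   ]
  },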
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## `TensorFlow` implementation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def scaled_dot_product_attention(q, k, v, mask=None):\n",
    "    \"\"\"\n",
    "    :param q: (batch_size, seq_len_q, depth)\n",
    "    :param k: (batch_size, seq_len_k, depth)\n",
    "    :param v: (batch_size, seq_len_v, depth_v), with seq_len_v == seq_len_k\n",
    "    :param mask: float tensor broadcastable to (batch_size, seq_len_q, seq_len_k), or None;\n",
    "                 1 marks a position to mask out (the PyTorch version above uses 0 for that)\n",
    "    :return: output (batch_size, seq_len_q, depth_v) and\n",
    "             attention_weights (batch_size, seq_len_q, seq_len_k)\n",
    "    \"\"\"\n",
    "    matmul_qk = tf.matmul(q, k, transpose_b=True)\n",
    "    # (..., seq_len_q, seq_len_k)\n",
    "\n",
    "    # Scale by the square root of the key depth\n",
    "    dk = tf.cast(tf.shape(k)[-1], tf.float32)\n",
    "    scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)\n",
    "\n",
    "    if mask is not None:\n",
    "        # Add a large negative value at masked positions so that softmax\n",
    "        # assigns them (almost) zero weight\n",
    "        scaled_attention_logits += (mask * -1e9)\n",
    "\n",
    "    # Softmax over the last axis (seq_len_k): each query's weights sum to 1\n",
    "    attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)\n",
    "    # (..., seq_len_q, seq_len_k)\n",
    "\n",
    "    output = tf.matmul(attention_weights, v)\n",
    "    # (..., seq_len_q, depth_v)\n",
    "\n",
    "    return output, attention_weights"
   ]
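  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Unlike the PyTorch version, this function expects `mask` to hold 1 at the positions to hide. This self-contained sketch builds two common masks under that convention. The helper names `create_padding_mask` and `create_look_ahead_mask` are illustrative, and token id 0 is assumed to be padding. The last lines apply the same `logits + mask * -1e9` step followed by softmax. They use hypothetical all-zero logits so the effect of the mask is easy to read."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "\n",
    "\n",
    "def create_padding_mask(seq):\n",
    "    # 1.0 where the token id is 0 (assumed padding id), else 0.0\n",
    "    mask = tf.cast(tf.math.equal(seq, 0), tf.float32)\n",
    "    # (batch_size, 1, seq_len): broadcasts over every query position\n",
    "    return mask[:, tf.newaxis, :]\n",
    "\n",
    "\n",
    "def create_look_ahead_mask(size):\n",
    "    # 1.0 strictly above the diagonal: query i cannot see keys j > i\n",
    "    return 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)\n",
    "\n",
    "\n",
    "tokens = tf.constant([[7, 6, 0, 0], [1, 2, 3, 4]])\n",
    "padding_mask = create_padding_mask(tokens)  # (2, 1, 4)\n",
    "look_ahead_mask = create_look_ahead_mask(4)  # (4, 4)\n",
    "\n",
    "# Hypothetical all-zero logits, so the unmasked keys share the weight equally\n",
    "logits = tf.zeros((2, 4, 4))\n",
    "weights = tf.nn.softmax(logits + padding_mask * -1e9, axis=-1)\n",
    "\n",
    "print(look_ahead_mask.numpy())\n",
    "print(weights[0, 0].numpy())  # padding keys 2 and 3 get ~0 weight"
   ]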
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
